Skip to content

test(aggregation): Improve interactive plotter#641

Merged
ValerianRey merged 25 commits intoSimplexLab:mainfrom
rkhosrowshahi:feature/interactive-plotting-ui
Apr 14, 2026
Merged

test(aggregation): Improve interactive plotter#641
ValerianRey merged 25 commits intoSimplexLab:mainfrom
rkhosrowshahi:feature/interactive-plotting-ui

Conversation

@rkhosrowshahi
Copy link
Copy Markdown
Contributor

@rkhosrowshahi rkhosrowshahi commented Apr 13, 2026

Summary

Interactive plotter

  • tests/plots/interactive_plotter.py and tests/plots/_utils.py: factory-based aggregators, checklist selection, improved gradient controls

rkhosrowshahi and others added 21 commits April 9, 2026 10:39
Implement Gradient Vaccine (ICLR 2021) as a stateful Jacobian aggregator.
Support group_type 0 (whole model), 1 (all_layer via encoder), and 2
(all_matrix via shared_params), with DEFAULT_GRADVAC_EPS and configurable
eps. Add Sphinx page and unit tests. Autogram is not supported; use
torch.manual_seed for reproducible task shuffle order.
- Add GOVERNANCE.md documenting technical governance structure
- Add CODEOWNERS file defining project maintainers
- Add CODE_OF_CONDUCT.md referencing Linux Foundation CoC

These files are required for PyTorch Ecosystem membership.

---------

Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com>
- Use group_type "whole_model" | "all_layer" | "all_matrix" instead of 0/1/2
- Remove DEFAULT_GRADVAC_EPS from the public API; keep default 1e-8; allow eps=0
- Validate beta via setter; tighten GradVac repr/str expectations
- Fix all_layer leaf sizing via children() and parameters() instead of private fields
- Trim redundant GradVac.rst prose; align docs with the new API
- Tests: GradVac cases, value regression with torch.manual_seed for GradVac
- Plotter: factory dict + fresh aggregator instances per update; legend from
  selected keys; MathJax labels and live angle/length readouts in the sidebar

This commit includes GradVac implementation with Aggregator class.
…hting

GradVac only needs gradient norms and dot products, which are fully
determined by the Gramian. This makes GradVac compatible with the autogram path.

- Remove grouping parameters (group_type, encoder, shared_params) from GradVac
- Export GradVacWeighting publicly
Seed is already set to 0 because of the autoused fix_randomness fixture declared in conftest.py
Add a Grouping example page covering all four strategies from the GradVac
paper (whole_model, enc_dec, all_layer, all_matrix), with a runnable code
block for each. Update the GradVac docstring note to link to the new page
instead of the previous placeholder text. Fix trailing whitespace in
CHANGELOG.md.
@PierreQuinton
Copy link
Copy Markdown
Contributor

Ideally, this should be split in 3 PRs, in order:

  • grouping.rst
  • dradvac
  • interactive plotter

@ValerianRey
Copy link
Copy Markdown
Contributor

Ideally, this should be split in 3 PRs, in order:

  • grouping.rst
  • dradvac
  • interactive plotter

I merged main into this and removed grouping example (59e7942), so now this only does the improvement of the plotter. I'll also update the PR title and message accordingly.

@ValerianRey ValerianRey added cc: test Conventional commit type for changes to tests. labels Apr 14, 2026
@github-actions github-actions bot changed the title feat(aggregation): GradVac and enhanced interactive plotter test(aggregation): GradVac and enhanced interactive plotter Apr 14, 2026
@ValerianRey ValerianRey changed the title test(aggregation): GradVac and enhanced interactive plotter test(aggregation): Improve interactive plotter Apr 14, 2026
@ValerianRey
Copy link
Copy Markdown
Contributor

/opencode:Plan please review this.

@opencode-agent
Copy link
Copy Markdown
Contributor

Now I have both versions. Let me analyze the changes:

Code Review: Interactive Plotter Improvements

Summary of Changes

tests/plots/_utils.py

Key Changes:

  1. Factory pattern for aggregators: Changed from storing aggregator instances to storing factory functions (Callable[[], Aggregator])
  2. Key-based selection: Uses selected_keys to track which aggregators to display
  3. Dynamic instantiation: Aggregators are created fresh each time make_fig() is called
  4. Added gradient vector text positioning: Now explicitly positions labels at "top right"

Benefits:

  • Ensures fresh aggregator state for each plot (important for stateful aggregators like Random, GradDrop)
  • Cleaner separation between available aggregators and selected ones

tests/plots/interactive_plotter.py

Key Changes:

  1. Factory-based aggregator definitions: Each aggregator is now a lambda that creates a fresh instance
  2. Added helper functions: _format_angle_display() and _format_length_display() for human-readable values
  3. Improved gradient controls:
    • Added formatted value displays showing both radians and degrees for angles
    • Better styling with labels and monospace font for values
    • Uses MathJax for gradient labels (e.g., $g_{1}$)
  4. Updated callback signatures: update_gradient_coordinate now also updates the displayed values

Issues Found

1. Missing Unpack import handling

The code imports Unpack from typing_extensions but this is only available in Python 3.11+. For older Python versions, this could cause issues.

Location: interactive_plotter.py, line 11

2. Unused variable in loop

In update_gradient_coordinate, the display_parts list is built but could be simplified.

Current code:

display_parts: list[str] = []
for j in range(len(values_) // 2):
    ...
    display_parts.append(_format_angle_display(angle))
    display_parts.append(_format_length_display(r))
return (plotter.make_fig(), *display_parts)

This works but is slightly verbose. Not a blocker.

3. Potential issue with TrimmedMean(trim_number=1)

With the default 3 tasks, TrimmedMean(trim_number=1) will trim 1 from each side (2 total), leaving only 1 gradient. This might produce unexpected results but is acceptable for testing.

Positive Aspects

  1. Factory pattern is well-implemented: Solves the stateful aggregator issue properly
  2. UI improvements are excellent: The formatted displays and better styling significantly improve usability
  3. Clean code structure: The helper functions and styling dictionaries are well-organized
  4. MathJax integration: Using LaTeX formatting for gradient labels is a nice touch

Recommendations

  1. Consider adding a comment about why the factory pattern is used (stateful aggregators):
# Use factory functions to ensure fresh aggregator state for each plot
# (important for stateful aggregators like Random, GradDrop)
  1. The typing_extensions dependency - Verify this is already in the project's dependencies since Unpack is used.

Overall Assessment

Approve - This is a well-structured improvement that:

  • Properly addresses state management for aggregators
  • Significantly improves the UI/UX of the interactive plotter
  • Follows the existing code style
  • Makes the testing tool more useful and user-friendly

The changes are focused, minimal, and achieve the stated goals of factory-based aggregators, checklist selection, and improved gradient controls.

New%20session%20-%202026-04-14T14%3A12%3A47.842Z
opencode session  |  github run

@ValerianRey
Copy link
Copy Markdown
Contributor

The code imports Unpack from typing_extensions but this is only available in Python 3.11+. For older Python versions, this could cause issues.

The interactive plotter is only for torchjd devs (not people who pip install the library), so it's fine if it doesn't work with python 3.10. So IMO this is not an issue.

Copy link
Copy Markdown
Contributor

@ValerianRey ValerianRey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really nice. Few nitpicks and we can merge.

Comment on lines +59 to +61
"AlignedMTL-min": lambda: AlignedMTL(scale_mode="min"),
"AlignedMTL-median": lambda: AlignedMTL(scale_mode="median"),
"AlignedMTL-RMSE": lambda: AlignedMTL(scale_mode="rmse"),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like it would be nice to change the __str__ method of AlignedMTL to also include the scale_mode. For example, str(AlignedMTL(scale_mode="min")) would become AlignedMTL-min.

But this is quite independent from this PR, and should come in a different PR if you care about that! @rkhosrowshahi

@ValerianRey
Copy link
Copy Markdown
Contributor

We could make another PR that adds the possibility to change the number of gradients. A few things to keep in mind:

  • Some aggregators like NashMTL will have to be re-created based on the number of selected gradients (so that the num_tasks parameters can match that).
  • Some aggregators like TrimmedMean require at least 3 gradients, and there's probably some aggregators that require at least 2. They would have to be deactivated when the number of gradients is not enough.

So it's not trivial, but if you're interested in doing that, we could land the PR quite easily @rkhosrowshahi

@ValerianRey ValerianRey merged commit 86fe403 into SimplexLab:main Apr 14, 2026
14 of 15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cc: test Conventional commit type for changes to tests. package: aggregation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants